Latent space encoding using LSTMs: Finding similar word context

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics.pairwise import euclidean_distances

import numpy as np
import pandas as pd
In [2]:
# Context window: we keep only the last 3 tokens of each review.
maxlen = 3
# Vocabulary size: restrict the dataset to the 1000 most frequent words.
max_features = 1000

Let's prepare the encoding that the Keras dataloader uses, so we can encode input, and reverse the output:

In [3]:
# Keras' raw IMDB word index starts at 1; shift every id by 3 so the four
# special tokens the dataset loader reserves can occupy ids 0-3.
word_to_id = keras.datasets.imdb.get_word_index()
word_to_id = {word: index + 3 for word, index in word_to_id.items()}
for reserved_id, token in enumerate(["<PAD>", "<START>", "<UNK>", "<UNUSED>"]):
    word_to_id[token] = reserved_id
# Inverse mapping used later to turn model ids back into readable words.
id_to_word = {index: word for word, index in word_to_id.items()}

Get the data

We load the data and preprocess it so the LSTMs can process it. This also handles the padding, in case a review is shorter than the defined sequence length.

In [4]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)
In [5]:
# Force every review to exactly `maxlen` ids (Keras' default truncation is
# 'pre', i.e. it keeps the final tokens — confirm against the Keras docs).
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=maxlen)
In [6]:
# Sanity check: 25k reviews, each now exactly `maxlen` ids long.
x_train.shape
Out[6]:
(25000, 3)
In [7]:
# Peek at one encoded example.
x_train[0]
Out[7]:
array([ 19, 178,  32], dtype=int32)
In [8]:
all = []
for it in range(x_train.shape[0]):
    row = np.zeros((maxlen, max_features))
    for jt in range(maxlen):
        row[jt, x_train[it, jt]] = 1
    all.append(row * 1.0)
In [9]:
data_enc = np.array(all)
In [10]:
# Expect (n_samples, maxlen, max_features).
data_enc.shape
Out[10]:
(25000, 3, 1000)
In [11]:
# Round-trip check: argmax over the vocabulary axis recovers the original ids.
np.argmax(data_enc[0], axis=1).tolist()
Out[11]:
[19, 178, 32]

Build the model

In [12]:
# Encoder: variable-length token ids -> a 3-dimensional latent vector.
inputs = keras.Input(shape=(None,), dtype="int32")
embedded = layers.Embedding(max_features, 128)(inputs)
# Bidirectional LSTM collapses the sequence into a single 256-wide summary.
sequence_summary = layers.Bidirectional(layers.LSTM(128))(embedded)
normalized = layers.BatchNormalization()(sequence_summary)
encoded = layers.Dense(3)(normalized)
In [18]:
# Decoder head: expand the 3-D latent code back into a sequence.
# RepeatVector tiles the latent vector `maxlen` times so the LSTM sees one
# copy per output position.
x = layers.Dense(3)(layers.RepeatVector(maxlen)(encoded))
x = layers.BatchNormalization()(x)
x = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(x)
# Per-position distribution over the vocabulary; the layer names "decoded"
# and "sentiment" are referenced by the loss dict passed to compile().
decoded = layers.TimeDistributed(layers.Dense(max_features))(x)
decoded = layers.Softmax(name="decoded")(decoded)
# Sentiment head branches off the LSTM output at the last time step.
sentiment = layers.Dense(1)(x[:, -1])
sentiment = layers.Activation('sigmoid', name="sentiment")(sentiment)
In [20]:
# Two-headed model: reconstruct the one-hot inputs AND predict sentiment.
model = keras.Model(inputs, [decoded, sentiment])
model.summary()
Model: "functional_7"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, None, 128)    128000      input_1[0][0]                    
__________________________________________________________________________________________________
bidirectional (Bidirectional)   (None, 256)          263168      embedding[0][0]                  
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 256)          1024        bidirectional[0][0]              
__________________________________________________________________________________________________
dense (Dense)                   (None, 3)            771         batch_normalization[0][0]        
__________________________________________________________________________________________________
repeat_vector_3 (RepeatVector)  (None, 3, 3)         0           dense[0][0]                      
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 3, 3)         12          repeat_vector_3[0][0]            
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 3, 3)         12          dense_8[0][0]                    
__________________________________________________________________________________________________
bidirectional_4 (Bidirectional) (None, 3, 256)       135168      batch_normalization_4[0][0]      
__________________________________________________________________________________________________
tf_op_layer_strided_slice_1 (Te [(None, 256)]        0           bidirectional_4[0][0]            
__________________________________________________________________________________________________
time_distributed_3 (TimeDistrib (None, 3, 1000)      257000      bidirectional_4[0][0]            
__________________________________________________________________________________________________
dense_10 (Dense)                (None, 1)            257         tf_op_layer_strided_slice_1[0][0]
__________________________________________________________________________________________________
decoded (Softmax)               (None, 3, 1000)      0           time_distributed_3[0][0]         
__________________________________________________________________________________________________
sentiment (Activation)          (None, 1)            0           dense_10[0][0]                   
==================================================================================================
Total params: 785,412
Trainable params: 784,894
Non-trainable params: 518
__________________________________________________________________________________________________
In [21]:
# Encoder sub-model: token ids -> 3-D latent vector.
encoder = keras.Model(inputs, encoded)
# NOTE(review): despite the name, this is the full autoencoder (ids ->
# reconstruction), not a latent->tokens decoder — the decode layers were
# wired directly to `encoded`, so a standalone decoder cannot be built
# from these layer calls. Confirm the name is intentional.
decoder = keras.Model(inputs, decoded)
In [22]:
# Joint objective: categorical CE on the per-position reconstruction,
# binary CE on the sentiment head (keys match the output layer names).
model.compile(optimizer='adam', loss={'decoded': 'categorical_crossentropy', 'sentiment': 'binary_crossentropy'})
In [23]:
# Train both heads for 50 epochs on the full training set.
model.fit(x_train, {'decoded': data_enc, 'sentiment': y_train}, epochs=50)
Epoch 1/50
782/782 [==============================] - 7s 9ms/step - loss: 4.8706 - decoded_loss: 4.1762 - sentiment_loss: 0.6944
Epoch 2/50
782/782 [==============================] - 7s 9ms/step - loss: 4.2768 - decoded_loss: 3.5827 - sentiment_loss: 0.6940
Epoch 3/50
782/782 [==============================] - 7s 9ms/step - loss: 3.9768 - decoded_loss: 3.2835 - sentiment_loss: 0.6933
Epoch 4/50
782/782 [==============================] - 7s 9ms/step - loss: 3.7059 - decoded_loss: 3.0140 - sentiment_loss: 0.6919
Epoch 5/50
782/782 [==============================] - 8s 10ms/step - loss: 3.5577 - decoded_loss: 2.8696 - sentiment_loss: 0.6881
Epoch 6/50
782/782 [==============================] - 8s 10ms/step - loss: 3.4557 - decoded_loss: 2.7705 - sentiment_loss: 0.6852
Epoch 7/50
782/782 [==============================] - 8s 10ms/step - loss: 3.3726 - decoded_loss: 2.6928 - sentiment_loss: 0.6798
Epoch 8/50
782/782 [==============================] - 8s 11ms/step - loss: 3.3032 - decoded_loss: 2.6284 - sentiment_loss: 0.6748
Epoch 9/50
782/782 [==============================] - 8s 10ms/step - loss: 3.2301 - decoded_loss: 2.5580 - sentiment_loss: 0.6721
Epoch 10/50
782/782 [==============================] - 9s 12ms/step - loss: 3.1818 - decoded_loss: 2.5158 - sentiment_loss: 0.6661
Epoch 11/50
782/782 [==============================] - 15s 19ms/step - loss: 3.1239 - decoded_loss: 2.4614 - sentiment_loss: 0.6625
Epoch 12/50
782/782 [==============================] - 17s 22ms/step - loss: 3.0868 - decoded_loss: 2.4265 - sentiment_loss: 0.6604
Epoch 13/50
782/782 [==============================] - 12s 16ms/step - loss: 3.0491 - decoded_loss: 2.3899 - sentiment_loss: 0.6592
Epoch 14/50
782/782 [==============================] - 10s 13ms/step - loss: 3.0184 - decoded_loss: 2.3606 - sentiment_loss: 0.6577
Epoch 15/50
782/782 [==============================] - 9s 11ms/step - loss: 2.9864 - decoded_loss: 2.3336 - sentiment_loss: 0.6529
Epoch 16/50
782/782 [==============================] - 9s 11ms/step - loss: 2.9438 - decoded_loss: 2.2922 - sentiment_loss: 0.6516
Epoch 17/50
782/782 [==============================] - 8s 11ms/step - loss: 2.9173 - decoded_loss: 2.2690 - sentiment_loss: 0.6483
Epoch 18/50
782/782 [==============================] - 8s 11ms/step - loss: 2.8806 - decoded_loss: 2.2384 - sentiment_loss: 0.6422
Epoch 19/50
782/782 [==============================] - 9s 11ms/step - loss: 2.8647 - decoded_loss: 2.2247 - sentiment_loss: 0.6399
Epoch 20/50
782/782 [==============================] - 9s 12ms/step - loss: 2.8456 - decoded_loss: 2.2073 - sentiment_loss: 0.6383
Epoch 21/50
782/782 [==============================] - 9s 12ms/step - loss: 2.7910 - decoded_loss: 2.1554 - sentiment_loss: 0.6355
Epoch 22/50
782/782 [==============================] - 9s 12ms/step - loss: 2.7914 - decoded_loss: 2.1585 - sentiment_loss: 0.6329
Epoch 23/50
782/782 [==============================] - 9s 12ms/step - loss: 2.7636 - decoded_loss: 2.1344 - sentiment_loss: 0.6292
Epoch 24/50
782/782 [==============================] - 9s 12ms/step - loss: 2.7376 - decoded_loss: 2.1125 - sentiment_loss: 0.6251
Epoch 25/50
782/782 [==============================] - 10s 12ms/step - loss: 2.7454 - decoded_loss: 2.1203 - sentiment_loss: 0.6251
Epoch 26/50
782/782 [==============================] - 10s 12ms/step - loss: 2.6833 - decoded_loss: 2.0611 - sentiment_loss: 0.6222
Epoch 27/50
782/782 [==============================] - 9s 12ms/step - loss: 2.6825 - decoded_loss: 2.0621 - sentiment_loss: 0.6204
Epoch 28/50
782/782 [==============================] - 8s 10ms/step - loss: 2.6815 - decoded_loss: 2.0641 - sentiment_loss: 0.6174
Epoch 29/50
782/782 [==============================] - 8s 10ms/step - loss: 2.6578 - decoded_loss: 2.0430 - sentiment_loss: 0.6148
Epoch 30/50
782/782 [==============================] - 8s 10ms/step - loss: 2.6317 - decoded_loss: 2.0197 - sentiment_loss: 0.6120
Epoch 31/50
782/782 [==============================] - 8s 10ms/step - loss: 2.6220 - decoded_loss: 2.0137 - sentiment_loss: 0.6083
Epoch 32/50
782/782 [==============================] - 8s 10ms/step - loss: 2.6048 - decoded_loss: 2.0014 - sentiment_loss: 0.6034
Epoch 33/50
782/782 [==============================] - 8s 10ms/step - loss: 2.5578 - decoded_loss: 1.9582 - sentiment_loss: 0.5996
Epoch 34/50
782/782 [==============================] - 7s 9ms/step - loss: 2.5668 - decoded_loss: 1.9664 - sentiment_loss: 0.6005
Epoch 35/50
782/782 [==============================] - 7s 9ms/step - loss: 2.5629 - decoded_loss: 1.9652 - sentiment_loss: 0.5977
Epoch 36/50
782/782 [==============================] - 7s 9ms/step - loss: 2.5388 - decoded_loss: 1.9407 - sentiment_loss: 0.5981
Epoch 37/50
782/782 [==============================] - 7s 9ms/step - loss: 2.5197 - decoded_loss: 1.9281 - sentiment_loss: 0.5915
Epoch 38/50
782/782 [==============================] - 7s 10ms/step - loss: 2.5065 - decoded_loss: 1.9206 - sentiment_loss: 0.5859
Epoch 39/50
782/782 [==============================] - 8s 10ms/step - loss: 2.5238 - decoded_loss: 1.9346 - sentiment_loss: 0.5892
Epoch 40/50
782/782 [==============================] - 8s 10ms/step - loss: 2.4922 - decoded_loss: 1.9108 - sentiment_loss: 0.5814
Epoch 41/50
782/782 [==============================] - 7s 10ms/step - loss: 2.4684 - decoded_loss: 1.8867 - sentiment_loss: 0.5817
Epoch 42/50
782/782 [==============================] - 7s 9ms/step - loss: 2.4771 - decoded_loss: 1.8968 - sentiment_loss: 0.5804
Epoch 43/50
782/782 [==============================] - 7s 9ms/step - loss: 2.5007 - decoded_loss: 1.9183 - sentiment_loss: 0.5824
Epoch 44/50
782/782 [==============================] - 7s 9ms/step - loss: 2.4622 - decoded_loss: 1.8891 - sentiment_loss: 0.5731
Epoch 45/50
782/782 [==============================] - 7s 9ms/step - loss: 2.4273 - decoded_loss: 1.8521 - sentiment_loss: 0.5751
Epoch 46/50
782/782 [==============================] - 8s 10ms/step - loss: 2.4409 - decoded_loss: 1.8651 - sentiment_loss: 0.5758
Epoch 47/50
782/782 [==============================] - 8s 10ms/step - loss: 2.4199 - decoded_loss: 1.8503 - sentiment_loss: 0.5696
Epoch 48/50
782/782 [==============================] - 8s 10ms/step - loss: 2.4055 - decoded_loss: 1.8381 - sentiment_loss: 0.5674
Epoch 49/50
782/782 [==============================] - 7s 9ms/step - loss: 2.3938 - decoded_loss: 1.8305 - sentiment_loss: 0.5633
Epoch 50/50
782/782 [==============================] - 7s 9ms/step - loss: 2.4156 - decoded_loss: 1.8487 - sentiment_loss: 0.5668
Out[23]:
<tensorflow.python.keras.callbacks.History at 0x7ffcb49a3950>
In [27]:
# Run both heads on the held-out test set.
res = model.predict(x_test)
In [29]:
# (reconstruction probabilities, sentiment scores)
res[0].shape, res[1].shape
Out[29]:
((25000, 3, 1000), (25000, 1))

Displaying the latent space

In [32]:
import plotly.graph_objects as go
In [33]:
# Project every training example into the 3-D latent space.
enc = encoder.predict(x_train)
In [34]:
# Peek at the first latent vector.
enc[0]
Out[34]:
array([ 7.3765574,  2.6015177, -4.1122513], dtype=float32)

Display by sentiment

In [35]:
plot_data = [[], []]
In [36]:
for it in range(y_train.shape[0]):
    plot_data[y_train[it]].append(enc[it])
In [63]:
# One 3-D scatter trace per sentiment class, built in a loop instead of two
# copy-pasted trace literals (same traces, same order: class 0 then class 1).
traces = []
for class_points in plot_data:
    points = np.array(class_points)
    traces.append(
        go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='markers',
            marker={'size': 1},
        )
    )
fig = go.Figure(traces)
# Export to HTML so the interactive plot survives outside the live kernel.
fig.write_html('plot_sentiment.html')
In [38]:
fig.show()

See the attached HTML file to explore the plot; Plotly's interactive output usually does not render once the notebook is exported to a static format.

Let's test it out...

First we create the input sequence that we want to run the model against.

In [64]:
# Encode a probe phrase with the same word index the dataset loader uses.
testing = [word_to_id[word] for word in ['best', 'movie', 'ever']]
In [65]:
testing
Out[65]:
[118, 20, 126]

Let's run the encoder and see where our sequence falls in the latent space.

In [66]:
# Project the probe phrase into the latent space (batch of one).
testing_space = encoder.predict(np.array([testing]))
In [67]:
testing_space
Out[67]:
array([[13.41354  ,  6.775359 ,  4.5145726]], dtype=float32)

We can use the Euclidean distance to find which training samples in the latent space are closest to the input sequence we just encoded. After sorting the array by distance, the closest matches should appear at the top:

In [68]:
full_distances = euclidean_distances(np.array([testing_space[0]]), enc)
In [69]:
distances = np.stack(
    [
        full_distances.reshape(25000, 1),
        np.array([[it] for it in range(25000)]),
    ],
    axis=1
).reshape(25000, 2).tolist()
distances = sorted(distances, key=lambda x: x[0])
In [70]:
distances[:10]
Out[70]:
[[1.3486991292666062e-06, 13730.0],
 [1.3486991292666062e-06, 21848.0],
 [0.6842671632766724, 8535.0],
 [0.7403992414474487, 19227.0],
 [0.8128081560134888, 23070.0],
 [0.8252330422401428, 11091.0],
 [0.9119076132774353, 18118.0],
 [0.9374194145202637, 24951.0],
 [0.9530797600746155, 5782.0],
 [0.9673166275024414, 17116.0]]

Now we can display all the queries that the model encoded as "similar" to our input query.

In [71]:
# Decode the 20 nearest training examples back into words and print
# "index distance [tokens]" for each.
for distance, raw_idx in distances[:20]:
    idx = int(raw_idx)
    words = [id_to_word[token] for token in x_train[idx]]
    print(idx, distance, words)
13730 1.3486991292666062e-06 ['best', 'movie', 'ever']
21848 1.3486991292666062e-06 ['best', 'movie', 'ever']
8535 0.6842671632766724 ['modern', 'cinema', 'masterpiece']
19227 0.7403992414474487 ['art', 'form', "'"]
23070 0.8128081560134888 ['family', 'movie', 'night']
11091 0.8252330422401428 ['british', 'crime', 'film']
18118 0.9119076132774353 ['top', 'class', 'cinema']
24951 0.9374194145202637 ['tv', 'show', 'ever']
5782 0.9530797600746155 ['other', 'then', 'entertainment']
17116 0.9673166275024414 ['best', 'films', 'ever']
5420 1.0053369998931885 ['ten', 'movies', 'ever']
21925 1.0175693035125732 ['tv', 'movies', 'go']
24942 1.0481247901916504 ['screen', 'career', 'enjoy']
7274 1.1147147417068481 ['sad', 'sad', 'film']
11157 1.123815655708313 ['christmas', 'entertainment', 'ever']
20455 1.1430972814559937 ['personal', 'level', 'enjoy']
7895 1.2489171028137207 ['best', 'movies', 'ever']
594 1.2569175958633423 ['well', 'made', 'film']
5409 1.2825756072998047 ['indeed', 'enjoyable', 'entertainment']
24181 1.320688009262085 ['production', 'house', 'both']
In [ ]: